{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 2.1.0\n",
      "\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "loading embedding ...\n",
      "embedding loaded as `glove_embedding`\n"
     ]
    }
   ],
   "source": [
    "# Bootstrap notebook. Per its printed log it loads the raw data packs\n",
    "# (`train_pack_raw`, `dev_pack_raw`, `test_pack_raw`), initialises a default\n",
    "# `ranking_task`, and loads `glove_embedding`.\n",
    "# NOTE(review): `mz` and `np`, used by later cells, are presumably imported there too - confirm.\n",
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3718.40it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:07<00:00, 2381.54it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 444321.96it/s]\n",
      "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 66270.86it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 42664.36it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 330095.71it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 373073.88it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 1847527.98it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3771.43it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:07<00:00, 2367.53it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 69352.98it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 97079.34it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 74474.07it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 334659.48it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 414870.15it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 43066.26it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 32836.35it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 3487.84it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 2442.70it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 67645.17it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 44196.33it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 76927.43it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 122242.02it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 333332.07it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 35858.80it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 28605.98it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 3830.90it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 2071.67it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 70304.99it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 85642.29it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 81757.54it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 145883.48it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 372769.40it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 35687.87it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 33140.49it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fit the preprocessor on the training pack only, then apply the fitted\n",
    "# transformation to the dev and test packs.\n",
    "preprocessor = mz.preprocessors.BasicPreprocessor(\n",
    "    fixed_length_left=10,\n",
    "    fixed_length_right=100,\n",
    "    remove_stop_words=False,\n",
    ")\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n",
    "dev_pack_processed, test_pack_processed = (\n",
    "    preprocessor.transform(raw_pack) for raw_pack in (dev_pack_raw, test_pack_raw)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7faf25c75f98>,\n",
       " 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7faf264499b0>,\n",
       " 'vocab_size': 16674,\n",
       " 'embedding_input_dim': 16674,\n",
       " 'input_shapes': [(10,), (100,)]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show what fitting recorded: the fitted filter/vocabulary units, the vocabulary\n",
    "# size, and the input shapes; the model cell copies this dict into `model.params`.\n",
    "preprocessor.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind `ranking_task` to a ranking task trained with rank cross-entropy\n",
    "# (10 negatives per positive) and reported with NDCG@3, NDCG@5 and MAP.\n",
    "rank_ce_loss = mz.losses.RankCrossEntropyLoss(num_neg=10)\n",
    "ranking_task = mz.tasks.Ranking(loss=rank_ce_loss)\n",
    "ndcg_metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=cutoff) for cutoff in (3, 5)\n",
    "]\n",
    "ranking_task.metrics = ndcg_metrics + [mz.metrics.MeanAveragePrecision()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           (None, 10, 300)      5002200     text_left[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 10, 1)        300         embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "match_histogram (InputLayer)    (None, 10, 30)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "attention_mask (Lambda)         (None, 10, 1)        0           dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 10, 10)       310         match_histogram[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "attention_probs (Lambda)        (None, 10, 1)        0           attention_mask[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 10, 1)        11          dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 1, 1)         0           attention_probs[0][0]            \n",
      "                                                                 dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 1)            0           dot_1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 1)            2           flatten_1[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 5,002,823\n",
      "Trainable params: 5,002,823\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Build and compile a DRMM model whose two inputs are `text_left` and a\n",
    "# `match_histogram` with `bin_size` bins per left-text position.\n",
    "bin_size = 30\n",
    "# Read the left text length from the fitted preprocessor instead of repeating\n",
    "# the literal 10, so it cannot drift from `fixed_length_left`.\n",
    "left_length = preprocessor.context['input_shapes'][0][0]\n",
    "\n",
    "model = mz.models.DRMM()\n",
    "model.params.update(preprocessor.context)\n",
    "# Replaces the (left, right) text shapes copied from the context just above.\n",
    "model.params['input_shapes'] = [[left_length], [left_length, bin_size]]\n",
    "model.params['task'] = ranking_task\n",
    "model.params['mask_value'] = 0\n",
    "model.params['embedding_output_dim'] = glove_embedding.output_dim\n",
    "model.params['mlp_num_layers'] = 1\n",
    "model.params['mlp_num_units'] = 10\n",
    "model.params['mlp_num_fan_out'] = 1\n",
    "model.params['mlp_activation_func'] = 'tanh'\n",
    "model.params['optimizer'] = 'adadelta'\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the embedding matrix for the fitted vocabulary and scale every row to\n",
    "# unit L2 norm (normalised embeddings make histogram generation fast).\n",
    "term_index = preprocessor.context['vocab_unit'].state['term_index']\n",
    "embedding_matrix = glove_embedding.build_matrix(term_index)\n",
    "l2_norm = np.linalg.norm(embedding_matrix, axis=1)\n",
    "# Guard against all-zero rows: dividing by a zero norm would yield NaN vectors\n",
    "# that end up in both the model weights and the histograms. Such rows stay zero.\n",
    "safe_norm = np.where(l2_norm > 0, l2_norm, 1.0)\n",
    "embedding_matrix = embedding_matrix / safe_norm[:, np.newaxis]\n",
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Histogram callback shared by all data generators below; presumably it fills the\n",
    "# model's `match_histogram` input - confirm against the MatchZoo docs.\n",
    "# Reuse `bin_size` from the model cell so the histogram width always matches\n",
    "# the model's declared input shape.\n",
    "hist_callback = mz.data_generator.callbacks.Histogram(\n",
    "    embedding_matrix,\n",
    "    bin_size=bin_size,\n",
    "    hist_mode='LCH',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the whole evaluation set once and hand it to a callback that\n",
    "# evaluates all task metrics every epoch and saves the model under\n",
    "# `model_save_path`.\n",
    "# NOTE(review): the per-epoch metrics are computed on `test_pack_processed`, while\n",
    "# `dev_pack_processed` is only used for the sample prediction further down -\n",
    "# confirm the splits are not swapped.\n",
    "pred_generator = mz.DataGenerator(\n",
    "    data_pack=test_pack_processed,\n",
    "    mode='point',\n",
    "    callbacks=[hist_callback],\n",
    ")\n",
    "pred_x, pred_y = pred_generator[:]\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(\n",
    "    model,\n",
    "    x=pred_x,\n",
    "    y=pred_y,\n",
    "    once_every=1,\n",
    "    batch_size=len(pred_y),\n",
    "    model_save_path='./drmm_pretrained_model/',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num batches: 255\n"
     ]
    }
   ],
   "source": [
    "# Pair-mode training batches with histograms attached by `hist_callback`.\n",
    "# `num_neg=10` is the same value given to RankCrossEntropyLoss earlier;\n",
    "# presumably the two must agree - verify before changing either.\n",
    "train_generator = mz.DataGenerator(\n",
    "    data_pack=train_pack_processed,\n",
    "    mode='pair',\n",
    "    num_dup=5,\n",
    "    num_neg=10,\n",
    "    batch_size=20,\n",
    "    callbacks=[hist_callback],\n",
    ")\n",
    "print('num batches:', len(train_generator))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "255/255 [==============================] - 29s 113ms/step - loss: 2.2520\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.565716344425806 - normalized_discounted_cumulative_gain@5(0.0): 0.6337418608669659 - mean_average_precision(0.0): 0.5867500331707677\n",
      "Epoch 2/30\n",
      "255/255 [==============================] - 44s 172ms/step - loss: 1.9129\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5985553932954132 - normalized_discounted_cumulative_gain@5(0.0): 0.6538053305162407 - mean_average_precision(0.0): 0.6119736749640002\n",
      "Epoch 3/30\n",
      "255/255 [==============================] - 44s 172ms/step - loss: 1.5735\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6009435695818752 - normalized_discounted_cumulative_gain@5(0.0): 0.6639328074555028 - mean_average_precision(0.0): 0.6141421880590333\n",
      "Epoch 4/30\n",
      "255/255 [==============================] - 44s 174ms/step - loss: 1.3388\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.60611654032328 - normalized_discounted_cumulative_gain@5(0.0): 0.657932990992144 - mean_average_precision(0.0): 0.6164241968035815\n",
      "Epoch 5/30\n",
      "255/255 [==============================] - 44s 173ms/step - loss: 1.2257\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6038327788753987 - normalized_discounted_cumulative_gain@5(0.0): 0.658207699559794 - mean_average_precision(0.0): 0.6165270742594192\n",
      "Epoch 6/30\n",
      "255/255 [==============================] - 45s 175ms/step - loss: 1.1705\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5996133695926983 - normalized_discounted_cumulative_gain@5(0.0): 0.6492549188602507 - mean_average_precision(0.0): 0.610038711779037\n",
      "Epoch 7/30\n",
      "255/255 [==============================] - 49s 194ms/step - loss: 1.1245\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.596951218733747 - normalized_discounted_cumulative_gain@5(0.0): 0.6537028205494899 - mean_average_precision(0.0): 0.6118594078618916\n",
      "Epoch 8/30\n",
      "255/255 [==============================] - 49s 192ms/step - loss: 1.0965\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.607384889011673 - normalized_discounted_cumulative_gain@5(0.0): 0.6591796919495273 - mean_average_precision(0.0): 0.6198677092456238\n",
      "Epoch 9/30\n",
      "255/255 [==============================] - 50s 196ms/step - loss: 1.0739\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6009170916196322 - normalized_discounted_cumulative_gain@5(0.0): 0.6585982395889274 - mean_average_precision(0.0): 0.61618597649167\n",
      "Epoch 10/30\n",
      "255/255 [==============================] - 50s 196ms/step - loss: 1.0617\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6033747474109487 - normalized_discounted_cumulative_gain@5(0.0): 0.6578329756640785 - mean_average_precision(0.0): 0.6179567226451466\n",
      "Epoch 11/30\n",
      "255/255 [==============================] - 51s 199ms/step - loss: 1.0452\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5997577679210652 - normalized_discounted_cumulative_gain@5(0.0): 0.653391959851057 - mean_average_precision(0.0): 0.6107675669593364\n",
      "Epoch 12/30\n",
      "255/255 [==============================] - 50s 197ms/step - loss: 1.0463\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6099707901807351 - normalized_discounted_cumulative_gain@5(0.0): 0.6576795760995855 - mean_average_precision(0.0): 0.6190800467908035\n",
      "Epoch 13/30\n",
      "255/255 [==============================] - 48s 189ms/step - loss: 1.0142\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6060017367789287 - normalized_discounted_cumulative_gain@5(0.0): 0.6561934627591615 - mean_average_precision(0.0): 0.6164463654573828\n",
      "Epoch 14/30\n",
      "255/255 [==============================] - 48s 190ms/step - loss: 1.0166\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5978152368225886 - normalized_discounted_cumulative_gain@5(0.0): 0.6505786036930443 - mean_average_precision(0.0): 0.61033087775462\n",
      "Epoch 15/30\n",
      "255/255 [==============================] - 49s 192ms/step - loss: 0.9978\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5922250615836083 - normalized_discounted_cumulative_gain@5(0.0): 0.6426732309894715 - mean_average_precision(0.0): 0.6022889901471488\n",
      "Epoch 16/30\n",
      "255/255 [==============================] - 49s 191ms/step - loss: 1.0025\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5750131412934796 - normalized_discounted_cumulative_gain@5(0.0): 0.6352853673494635 - mean_average_precision(0.0): 0.5900977589685756\n",
      "Epoch 17/30\n",
      "255/255 [==============================] - 51s 199ms/step - loss: 0.9862\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5758042381295027 - normalized_discounted_cumulative_gain@5(0.0): 0.6346575981857837 - mean_average_precision(0.0): 0.5901888624427178\n",
      "Epoch 18/30\n",
      "255/255 [==============================] - 49s 193ms/step - loss: 0.9855\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.579621264998872 - normalized_discounted_cumulative_gain@5(0.0): 0.6347686906366622 - mean_average_precision(0.0): 0.5884915666561875\n",
      "Epoch 19/30\n",
      "255/255 [==============================] - 49s 192ms/step - loss: 0.9766\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5828950659531305 - normalized_discounted_cumulative_gain@5(0.0): 0.6372592629514049 - mean_average_precision(0.0): 0.5918219888593219\n",
      "Epoch 20/30\n",
      "255/255 [==============================] - 48s 190ms/step - loss: 0.9680\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5825563351056543 - normalized_discounted_cumulative_gain@5(0.0): 0.6391007734169843 - mean_average_precision(0.0): 0.5938890573634718\n",
      "Epoch 21/30\n",
      "255/255 [==============================] - 49s 191ms/step - loss: 0.9508\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.579128022658977 - normalized_discounted_cumulative_gain@5(0.0): 0.6395338143562516 - mean_average_precision(0.0): 0.5941989088930982\n",
      "Epoch 22/30\n",
      "255/255 [==============================] - 50s 196ms/step - loss: 0.9522\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5757112036272017 - normalized_discounted_cumulative_gain@5(0.0): 0.6293793465362543 - mean_average_precision(0.0): 0.583203117231657\n",
      "Epoch 23/30\n",
      "255/255 [==============================] - 52s 203ms/step - loss: 0.9465\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5712280953991604 - normalized_discounted_cumulative_gain@5(0.0): 0.633607426512819 - mean_average_precision(0.0): 0.583670818529751\n",
      "Epoch 24/30\n",
      "255/255 [==============================] - 51s 200ms/step - loss: 0.9434\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.580920539337657 - normalized_discounted_cumulative_gain@5(0.0): 0.641607241519095 - mean_average_precision(0.0): 0.5944285471787385\n",
      "Epoch 25/30\n",
      "255/255 [==============================] - 51s 198ms/step - loss: 0.9410\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5804534006493391 - normalized_discounted_cumulative_gain@5(0.0): 0.6365914442161427 - mean_average_precision(0.0): 0.5909732454825581\n",
      "Epoch 26/30\n",
      "255/255 [==============================] - 48s 190ms/step - loss: 0.9288\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5706608764817882 - normalized_discounted_cumulative_gain@5(0.0): 0.6287030620891625 - mean_average_precision(0.0): 0.5812929865633114\n",
      "Epoch 27/30\n",
      "255/255 [==============================] - 49s 193ms/step - loss: 0.9390\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5609829297013442 - normalized_discounted_cumulative_gain@5(0.0): 0.6256839520090155 - mean_average_precision(0.0): 0.5787581434265507\n",
      "Epoch 28/30\n",
      "255/255 [==============================] - 49s 191ms/step - loss: 0.9276\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5719190626216555 - normalized_discounted_cumulative_gain@5(0.0): 0.6307496688555204 - mean_average_precision(0.0): 0.5817515076104938\n",
      "Epoch 29/30\n",
      "255/255 [==============================] - 49s 193ms/step - loss: 0.9183\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5795917195993303 - normalized_discounted_cumulative_gain@5(0.0): 0.629216010151653 - mean_average_precision(0.0): 0.5871250977359159\n",
      "Epoch 30/30\n",
      "255/255 [==============================] - 51s 200ms/step - loss: 0.9114\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5699846325745904 - normalized_discounted_cumulative_gain@5(0.0): 0.6288458929530589 - mean_average_precision(0.0): 0.581158143460481\n"
     ]
    }
   ],
   "source": [
    "# Train for 30 epochs; the `evaluate` callback reports metrics after each epoch.\n",
    "# Batches are produced by 30 worker processes.\n",
    "history = model.fit_generator(\n",
    "    train_generator,\n",
    "    epochs=30,\n",
    "    callbacks=[evaluate],\n",
    "    workers=30,\n",
    "    use_multiprocessing=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 2.1248043 ],\n",
       "       [-0.7405751 ],\n",
       "       [ 1.2434225 ],\n",
       "       [-3.161619  ],\n",
       "       [-1.5212955 ],\n",
       "       [-1.3774216 ],\n",
       "       [-0.43242887],\n",
       "       [-3.8191142 ],\n",
       "       [-2.490344  ],\n",
       "       [ 1.9016179 ]], dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reload a saved model and score the first 10 entries of the dev pack.\n",
    "# NOTE(review): '16' is presumably the checkpoint written after epoch 16 by the\n",
    "# evaluation callback - confirm, and consider picking the epoch from the metrics.\n",
    "drmm_model = mz.load_model('./drmm_pretrained_model/16')\n",
    "test_generator = mz.DataGenerator(\n",
    "    data_pack=dev_pack_processed[:10],\n",
    "    mode='point',\n",
    "    callbacks=[hist_callback],\n",
    ")\n",
    "test_x, test_y = test_generator[:]\n",
    "prediction = drmm_model.predict(test_x)\n",
    "prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up: delete the directory that was passed as `model_save_path` to the\n",
    "# evaluation callback, including every model saved inside it.\n",
    "import shutil\n",
    "\n",
    "shutil.rmtree(path='./drmm_pretrained_model/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
